{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "89cc7a4c-6155-4021-b73a-58ab9ae93789",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-20T04:11:12.505250Z",
     "iopub.status.busy": "2024-10-20T04:11:12.504959Z",
     "iopub.status.idle": "2024-10-20T04:11:12.509900Z",
     "shell.execute_reply": "2024-10-20T04:11:12.508505Z",
     "shell.execute_reply.started": "2024-10-20T04:11:12.505229Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a8926ae-7472-47a5-bd47-217397d96b37",
   "metadata": {},
   "source": [
    "$$\n",
    "-\\sum_{i=1}^{C} y_i\\log p_i\n",
    "$$\n",
    "\n",
    "- 其中 $C$ 为类别数，$y_i$ 为 one-hot 形式的标签，$p_i$ 为 softmax 之后的预测概率\n",
    "- 对于 imagenet 的 1000 分类问题\n",
    "\n",
    "$$\n",
    "-\\sum_{i=1}^{1000}y_i\\log p_i\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44ebbd0d-881d-4846-b5e8-604e804f20a4",
   "metadata": {},
   "source": [
    "- input_ids => logits\n",
    "- logits 要与 labels 的 shape 对齐：logits 为 (batch, seqlen, num_classes)，labels 为 (batch, seqlen)；\n",
    "    - 即除了最后的类别维之外，其余维度（至少 seqlen）要保持一致；\n",
    "- batch 进去算，默认也是平均意义的（即单样本级别的 loss）；\n",
    "    - https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html\n",
    "    - 默认 reduction = 'mean'（被 ignore_index 忽略的位置不计入平均）\n",
    "- https://www.bilibili.com/video/BV1NY4y1E76o/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2cf44faa-35ad-40a5-a864-0fb813da6bd7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-20T04:10:03.072233Z",
     "iopub.status.busy": "2024-10-20T04:10:03.071914Z",
     "iopub.status.idle": "2024-10-20T04:10:03.080631Z",
     "shell.execute_reply": "2024-10-20T04:10:03.079827Z",
     "shell.execute_reply.started": "2024-10-20T04:10:03.072215Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x724a0d6039f0>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# torch is already imported in the first cell; the import is repeated here.\n",
    "import torch\n",
    "# Seed PyTorch's global RNG so the torch.randn draws later in the notebook are repeatable on a fresh run.\n",
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3b4678d7-68f2-4ce7-b03e-72f2af5c8c1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulated inputs and labels, both of shape (2, 4).\n",
    "# The last label of each row is -100, the value passed as ignore_index to the loss below.\n",
    "input_ids = torch.tensor([[1, 2, 3, 4], [4, 3, 2, 1]])\n",
    "labels = torch.tensor([[2, 3, 4, -100], [3, 2, 1, -100]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "91d287bd-e573-404c-8420-f8ebde5c193e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulated model output with shape (batch_size, sequence_length, num_classes)\n",
    "logits = torch.randn(2, 4, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dd878f62-ce4a-4525-893e-d7d59ab6e056",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the loss function and set ignore_index, so targets equal to -100 do not contribute to the loss\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=-100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2d202f20-8962-42d8-bace-c5fc23161c0f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 10])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Merge the batch and sequence axes: (2, 4, 10) -> (8, 10), one row of class scores per position.\n",
    "logits.view(-1, logits.size(-1)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "684b318c-84fb-4811-9c89-42078d7b1145",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([   2,    3,    4, -100,    3,    2,    1, -100]), torch.Size([8]))"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Flatten the labels the same way: (2, 4) -> (8,), one target per row of the flattened logits.\n",
    "labels.view(-1), labels.view(-1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "bb98c0f8-6bcb-4333-9f13-f2798cc220a6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(2.5207)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compute the loss on the flattened logits (8, 10) and flattened labels (8,)\n",
    "loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "73762327-ba73-44ee-b2a5-2352ea94d550",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.2072, -1.6468, -2.2334, -5.2397, -2.4557, -4.3687, -3.1772, -4.7388,\n",
       "         -3.8863, -1.4854],\n",
       "        [-2.9390, -3.9501, -3.2744, -3.1059, -3.3153, -1.7841, -0.9042, -2.7061,\n",
       "         -3.0439, -2.1069],\n",
       "        [-4.0058, -2.1694, -2.4469, -1.5671, -1.9685, -1.9512, -2.6372, -1.9129,\n",
       "         -3.4793, -3.2059],\n",
       "        [-2.8926, -1.7811, -4.0257, -3.5122, -2.8644, -0.9236, -2.3221, -3.0655,\n",
       "         -2.3353, -3.4156],\n",
       "        [-4.2205, -1.6673, -3.5427, -3.2640, -3.9371, -0.5401, -3.8976, -3.1508,\n",
       "         -3.5767, -3.3210],\n",
       "        [-2.3889, -1.9412, -2.9550, -1.2756, -3.2810, -3.2030, -3.8702, -2.4310,\n",
       "         -2.5305, -1.7914],\n",
       "        [-3.5394, -1.5970, -4.6261, -2.0580, -1.9965, -2.5852, -1.2235, -2.9184,\n",
       "         -3.0949, -3.6389],\n",
       "        [-3.5348, -1.2022, -2.6524, -1.9564, -2.4236, -2.0539, -1.9052, -3.1219,\n",
       "         -4.6866, -3.2310]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-row log-probabilities: log(softmax(.)) over the class axis (dim=1) of the flattened logits.\n",
    "# NOTE(review): F.log_softmax(..., dim=1) is the numerically more stable way to get the same values.\n",
    "torch.log(nn.Softmax(dim=1)(logits.view(-1, logits.size(-1))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "7ed5f31f-941a-4670-a143-7fb6adbdf567",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.520633333333333"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Manual check of the loss: cross-entropy is the NEGATIVE mean of the target-class log-probabilities.\n",
    "# Each term is the log-softmax entry at (row, label); the two rows labelled -100 are skipped,\n",
    "# so the mean is taken over the 6 remaining positions.\n",
    "-((-2.2334 + (-3.1059) + (-1.9685)) + ((-3.2640) + (-2.9550) + (-1.5970)))/6"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ca3d4d8-b37b-456a-99b2-215926daf9bf",
   "metadata": {},
   "source": [
    "### labels 作为 id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a3e9a29f-7d67-4fab-8d81-63cfb6ea0d0d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-20T04:10:16.214681Z",
     "iopub.status.busy": "2024-10-20T04:10:16.213805Z",
     "iopub.status.idle": "2024-10-20T04:10:16.227020Z",
     "shell.execute_reply": "2024-10-20T04:10:16.225787Z",
     "shell.execute_reply.started": "2024-10-20T04:10:16.214636Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.9580,  1.3221,  0.8172, -0.7658, -0.7506],\n",
       "        [ 1.3525,  0.6863, -0.3278,  0.7950,  0.2815],\n",
       "        [ 0.0562,  0.5227, -0.2384, -0.0499,  0.5263]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random logits of shape (3, 5); used below as 3 samples scored over 5 classes.\n",
    "x = torch.randn(3, 5)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "11949d0c-2aab-4475-8c9d-4ba2d919631d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-20T04:12:30.597870Z",
     "iopub.status.busy": "2024-10-20T04:12:30.596461Z",
     "iopub.status.idle": "2024-10-20T04:12:30.603816Z",
     "shell.execute_reply": "2024-10-20T04:12:30.602555Z",
     "shell.execute_reply.started": "2024-10-20T04:12:30.597823Z"
    }
   },
   "outputs": [],
   "source": [
    "# Targets given as class indices, one per row of x.\n",
    "# NOTE: this rebinds `labels`, replacing the (2, 4) tensor defined earlier in the notebook.\n",
    "labels = torch.tensor([1, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "833a19ee-f6c6-419a-8362-abf20faa0f2b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-20T04:12:32.951274Z",
     "iopub.status.busy": "2024-10-20T04:12:32.950001Z",
     "iopub.status.idle": "2024-10-20T04:12:32.958996Z",
     "shell.execute_reply": "2024-10-20T04:12:32.958037Z",
     "shell.execute_reply.started": "2024-10-20T04:12:32.951226Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.8159)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Functional form of the loss, called with logits x of shape (3, 5) and class-index targets of shape (3,).\n",
    "F.cross_entropy(x, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4e1c6b26-e073-4587-83a4-03629e49e1af",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-20T04:12:03.174147Z",
     "iopub.status.busy": "2024-10-20T04:12:03.172890Z",
     "iopub.status.idle": "2024-10-20T04:12:03.186358Z",
     "shell.execute_reply": "2024-10-20T04:12:03.185208Z",
     "shell.execute_reply.started": "2024-10-20T04:12:03.174111Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.2995, -0.9353, -1.4403, -3.0233, -3.0081],\n",
       "        [-0.9613, -1.6276, -2.6417, -1.5189, -2.0324],\n",
       "        [-1.7646, -1.2980, -2.0591, -1.8706, -1.2944]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Row-wise log-probabilities of x: log(softmax(x)) over the class axis (dim=1).\n",
    "# NOTE(review): F.log_softmax(x, dim=1) is the numerically more stable way to get the same values.\n",
    "torch.log(F.softmax(x, dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "fe714b5c-1843-43d3-a0ba-d95570e06c96",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-20T04:12:58.270213Z",
     "iopub.status.busy": "2024-10-20T04:12:58.268656Z",
     "iopub.status.idle": "2024-10-20T04:12:58.277120Z",
     "shell.execute_reply": "2024-10-20T04:12:58.275879Z",
     "shell.execute_reply.started": "2024-10-20T04:12:58.270142Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.8158666666666665"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Manual check of F.cross_entropy(x, labels): the NEGATIVE mean of the log-softmax entries\n",
    "# at (row 0, class 1), (row 1, class 2) and (row 2, class 3).\n",
    "-(-0.9353 + (-2.6417) + (-1.8706))/3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b53acb9e-2039-4c73-aa2b-35ce339d8254",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
